import torchvision
from PIL import Image
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms

"""
torchvision  获取数据集 等一系列工具
"""

# Convert each PIL image returned by the dataset into a torch tensor.
transform = transforms.Compose([
    transforms.ToTensor(),
])

# Number of test samples to log to TensorBoard (capped at the dataset size).
NUM_PREVIEW = 10

# Fetch MNIST, downloading on first run.
# NOTE: the directory (and the TensorBoard tag below) is spelled "MINIST";
# it is kept as-is so existing downloads and logs continue to be reused.
train_set = torchvision.datasets.MNIST("../dataset/MINIST", transform=transform, download=True, train=True)
test_set = torchvision.datasets.MNIST("../dataset/MINIST", transform=transform, download=True, train=False)

# Inspect the first test sample. Because `transform` applies ToTensor, `img`
# is a tensor here, so it has to be converted back to a PIL image before it
# can be displayed with `.show()`.
img, target = test_set[0]
print(test_set.classes)
print(img)
print(target)
print(type(img))
# MNIST is single-channel, so ToTensor yields (1, H, W); drop that channel dim.
img = img.squeeze(0)
uploader = transforms.ToPILImage()  # tensor -> PIL.Image converter
img = uploader(img)
print(type(img))
img.show()

# Log the first NUM_PREVIEW test images (as CHW tensors) to TensorBoard.
# The context manager guarantees the writer is flushed and closed even if an
# exception is raised while logging.
with SummaryWriter("logs05") as writer:
    for i in range(min(NUM_PREVIEW, len(test_set))):
        img, target = test_set[i]
        print(target)
        writer.add_image("MINIST", img, i)